import numpy as np
import pandas as pd
import statsmodels.api as sm
import math
import sys

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Linear estimator: OLS of ``outcome`` on ``pbvarb`` with HC1 robust SEs.

    Fits ``outcome ~ pbvarb`` (with an intercept) and reads the fitted line at
    ``pbvarb == 1`` and ``pbvarb == 0`` as the Black and non-Black audit rates.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain the ``pbvarb`` and ``outcome`` columns.
    pbvarb : str
        Column holding the predicted probability of being Black.
    outcome : str
        Outcome column to regress.

    Returns
    -------
    tuple
        ``(coef, se, black_aud_rate, non_black_aud_rate)`` where ``coef`` is the
        slope on ``pbvarb``, ``se`` its HC1 standard error,
        ``black_aud_rate = intercept + coef`` and ``non_black_aud_rate = intercept``.
    """
    model = sm.OLS.from_formula(outcome + ' ~ ' + pbvarb, data=dataset).fit(cov_type='HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    # The intercept is the first parameter of an "y ~ x" formula. Use explicit
    # positional access: integer keys on a label-indexed Series
    # (model.params[0]) are a deprecated fallback in pandas.
    intercept = model.params.iloc[0]
    black_aud_rate = intercept + coef
    non_black_aud_rate = intercept
    print("number of obs :" + str(model.nobs))
    return coef, se, black_aud_rate, non_black_aud_rate

## Linear Estimator with controls
def regEstimate_control(dataset, pbvarb = 'predicted_prob_black', controls = '', outcome = 'audited'):
    """Linear estimator with additional covariates (OLS, HC1 robust SEs).

    Fits ``outcome ~ pbvarb + controls`` and returns the slope on ``pbvarb``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain ``outcome``, ``pbvarb`` and every column used in ``controls``.
    pbvarb : str
        Column holding the predicted probability of being Black.
    controls : str
        Right-hand-side formula fragment for the controls, e.g. ``"a + b"``.
        When empty (the default), the regression is run without controls.
    outcome : str
        Outcome column to regress.

    Returns
    -------
    tuple
        ``(coef, se)`` for the ``pbvarb`` term, with HC1 standard error.
    """
    rhs = pbvarb
    # Only append controls when provided. With the default '' the formula would
    # end in a dangling "+", which the formula parser rejects.
    if controls and controls.strip():
        rhs = rhs + ' + ' + controls
    model = sm.OLS.from_formula(outcome + ' ~ ' + rhs, data=dataset).fit(cov_type='HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    print("number of obs :" + str(model.nobs))
    return coef, se

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Probabilistic estimator of the Black / non-Black outcome gap.

    Each row counts towards the Black mean with weight ``p = dataset[pbvarb]``
    and towards the non-Black mean with weight ``1 - p``.

    Returns
    -------
    tuple
        ``(black_rate - non_black_rate, black_rate, non_black_rate)``.
    """
    weight_black = dataset[pbvarb]
    weight_other = 1 - weight_black
    y = dataset[outcome]

    # Probability-weighted means of the outcome for each group.
    rate_black = (weight_black * y).sum() / weight_black.sum()
    rate_other = (weight_other * y).sum() / weight_other.sum()

    return rate_black - rate_other, rate_black, rate_other

## Get standard errors
def getSEs(dataset, pbvarb='predicted_prob_black', outcome='audited', seReg=0.0):
    """Derive the probabilistic-estimator SE from a supplied regression SE.

    The regression standard error is taken as given via ``seReg`` (it is not
    recomputed here). It is scaled by ``getSEMultiplier(dataset, pbvarb)``.
    ``outcome`` is accepted for signature compatibility but is not used.
    With the default ``seReg=0.0`` the returned SEs are zero for any finite
    multiplier, so callers are expected to pass it explicitly.

    Returns
    -------
    tuple
        ``(seReg, seReg * multiplier)``.
    """
    return seReg, seReg * getSEMultiplier(dataset, pbvarb)

def getSEMultiplier(dataset, pbvarb='predicted_prob_black'):
    """Return sqrt(Var(p) / (mean(p) * (1 - mean(p)))) for ``p = dataset[pbvarb]``.

    This compares the sample variance of the probabilities (pandas default
    ``ddof=1``) to the variance of a Bernoulli variable with the same mean.
    """
    probs = dataset[pbvarb]
    p_bar = probs.mean()
    bernoulli_var = p_bar * (1 - p_bar)
    return np.sqrt(probs.var() / bernoulli_var)


### Operational audits and predicted race probabilities
cols = ['isEIC', 'predicted_prob_black', 'pre_refund', 'post_ref', 'corr_aud', 'non_corr_aud', 'aud_no_research_audits', 'taxpayer_id_new']
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final_new_eic.csv", usecols=cols)
# Report how many rows are dropped for a missing predicted probability. A bare
# `len(...)` expression is only echoed by an interactive prompt, so print it.
print("rows read: " + str(len(dataBISG)))
dataBISG = dataBISG.dropna(subset=['predicted_prob_black'])
print("rows with predicted_prob_black: " + str(len(dataBISG)))

# Split the sample by EITC status. Rows with any other isEIC value (e.g. NaN)
# fall in neither subset.
eitc_pop = dataBISG.loc[dataBISG['isEIC'] == 1]
non_eitc_pop = dataBISG.loc[dataBISG['isEIC'] == 0]


# Loop over each population n_iter_bootstrap times, bootstrapping the data on
# every iteration and calculating the linear-estimator outcomes of interest.
from timeit import default_timer as timer

n_iter_bootstrap = 100

# (label used in output column names, outcome column in the data).
# Order fixes both the regression order and the CSV column order.
boot_outcomes = [
    ('any', 'aud_no_research_audits'),
    ('pre', 'pre_refund'),
    ('post', 'post_ref'),
    ('corr', 'corr_aud'),
    ('ncorr', 'non_corr_aud'),
]
boot_columns = (['black_' + label for label, _ in boot_outcomes]
                + ['non_black_' + label for label, _ in boot_outcomes]
                + ['n'])

# (population name, data to resample, checkpoint CSV path).
boot_populations = [
    ('full_pop', dataBISG, '/REDACTED/levels_table_lin_full_pop_update.csv'),
    ('eitc_pop', eitc_pop, '/REDACTED/levels_table_lin_eitc_pop_update.csv'),
    ('non_eitc_pop', non_eitc_pop, '/REDACTED/levels_table_lin_non_eitc_pop_update.csv'),
]
# One list of per-iteration result rows for each population.
boot_rows = {name: [] for name, _, _ in boot_populations}

for i in range(n_iter_bootstrap):
    start = timer()
    print("Bootstrap iteration: " + str(i))
    for name, pop, _ in boot_populations:
        # Resample rows with replacement; seeding with i makes each iteration reproducible.
        samp = pop.sample(frac=1, random_state=i, replace=True)
        row = {}
        for label, outcome_col in boot_outcomes:
            # regEstimate returns (coef, se, black_rate, non_black_rate); store rates in percent.
            fit = regEstimate(samp, 'predicted_prob_black', outcome_col)
            row['black_' + label] = fit[2] * 100
            row['non_black_' + label] = fit[3] * 100
        row['n'] = len(samp)
        boot_rows[name].append(row)
    # Rewrite every CSV after each iteration so completed iterations survive an interrupted run.
    for name, _, out_path in boot_populations:
        pd.DataFrame(boot_rows[name], columns=boot_columns).to_csv(out_path, index=False)
    end = timer()
    print("Time: " + str(end - start))

# Read in the bootstrap results written by the loop (one row per iteration).
full_pop = pd.read_csv('/REDACTED/levels_table_lin_full_pop_update.csv')
eitc_pop = pd.read_csv('/REDACTED/levels_table_lin_eitc_pop_update.csv')
non_eitc_pop = pd.read_csv('/REDACTED/levels_table_lin_non_eitc_pop_update.csv')

# Write bootstrap means and standard deviations (reported as SEs) to a text file.
# Every value is printed explicitly: bare expressions such as `df.col.mean()`
# are only echoed by an interactive prompt, so a script run wrote labels only.
# Printing to the file handle (instead of reassigning sys.stdout) also closes
# the file reliably and leaves the caller's stdout untouched.
report_populations = [
    ('Full Population', full_pop),
    ('EITC Population', eitc_pop),
    ('Non-EITC Population', non_eitc_pop),
]
report_groups = [('Black', 'black_'), ('Non-Black', 'non_black_')]
report_outcomes = [('Any', 'any'), ('Pre', 'pre'), ('Post', 'post'), ('Corr', 'corr'), ('NCorr', 'ncorr')]

with open("/REDACTED/levels_table_lin_update.txt", "w", encoding="utf-8") as report:
    for pop_idx, (pop_label, results) in enumerate(report_populations):
        for group_idx, (group_label, prefix) in enumerate(report_groups):
            # Every section header except the very first is preceded by a blank line.
            lead = "" if pop_idx == 0 and group_idx == 0 else "\n"
            print(lead + group_label + " " + pop_label + "\n", file=report)
            for outcome_label, suffix in report_outcomes:
                estimates = results[prefix + suffix]
                print(outcome_label, file=report)
                print("Mean", file=report)
                print(estimates.mean(), file=report)
                print("SE", file=report)
                print(estimates.std(), file=report)
        # Mean bootstrap sample size for this population.
        print("N", file=report)
        print(results['n'].mean(), file=report)
